package com.novel.common.dao.impl;

import com.mongodb.BasicDBObject;
import com.mongodb.DBObject;
import com.novel.common.dao.BaseDao;
import com.novel.common.entity.Entity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.MongoWriter;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Objects;

/**
 * Generic MongoDB DAO implementation built on {@link MongoTemplate}.
 *
 * <p>Concrete DAOs extend this class with a concrete entity type argument, e.g.
 * {@code class BookDaoImpl extends BaseDaoImpl<Book>}. The entity class is resolved reflectively
 * from that type argument and is used as the target type of every query, insert, update and
 * delete issued by this class.
 *
 * @author 奔波儿灞
 * @since 1.0
 */
public abstract class BaseDaoImpl<T extends Entity> implements BaseDao<T> {

    /** MongoDB update operator that sets individual fields. */
    private static final String SET_KEY = "$set";

    @Autowired
    private MongoTemplate mongoTemplate;

    @Autowired
    private MongoConverter mongoConverter;

    /** Entity class bound to {@code T}; the target type of all template operations. */
    private final Class<T> entityClazz;

    /**
     * Resolves the entity class from the type argument this DAO passes to {@code BaseDaoImpl}.
     *
     * @throws IllegalStateException if no concrete entity class can be found in the hierarchy
     */
    @SuppressWarnings("unchecked")
    public BaseDaoImpl() {
        // Safe: resolveEntityClass returns exactly the class bound to T for BaseDaoImpl.
        entityClazz = (Class<T>) resolveEntityClass(getClass());
    }

    /**
     * Walks up from {@code daoClass} to the superclass declaration that parameterizes
     * {@code BaseDaoImpl} and returns its type argument.
     *
     * <p>Walking the hierarchy (instead of only inspecting the direct generic superclass) keeps
     * this working for subclasses of a concrete DAO, such as runtime proxy subclasses, whose
     * direct superclass is not a parameterized type.
     *
     * @param daoClass runtime class of the DAO instance
     * @return the concrete entity class
     * @throws IllegalStateException if the type argument is missing or is not a concrete class
     */
    private static Class<?> resolveEntityClass(Class<?> daoClass) {
        for (Class<?> current = daoClass; current != null; current = current.getSuperclass()) {
            Type superType = current.getGenericSuperclass();
            if (superType instanceof ParameterizedType
                    && ((ParameterizedType) superType).getRawType() == BaseDaoImpl.class) {
                Type argument = ((ParameterizedType) superType).getActualTypeArguments()[0];
                if (argument instanceof Class) {
                    return (Class<?>) argument;
                }
                // e.g. a type variable passed through an intermediate generic class
                break;
            }
        }
        throw new IllegalStateException("Cannot resolve the entity class of " + daoClass.getName()
                + ": it must extend BaseDaoImpl with a concrete entity type argument");
    }

    @Override
    public T findById(String id) {
        return mongoTemplate.findById(id, entityClazz);
    }

    /**
     * Finds all entities whose id is contained in {@code ids}.
     *
     * @param ids ids to look up; may be null or empty
     * @return the matching entities; an empty mutable list when {@code ids} is null or empty,
     *         in which case the database is not queried
     */
    @Override
    public List<T> findByIds(List<String> ids) {
        if (ids == null || ids.isEmpty()) {
            // An empty $in matches nothing, and a null list is not a valid $in operand.
            return new ArrayList<>();
        }
        return mongoTemplate.find(query(Criteria.where(Entity.ID_KEY).in(ids)), entityClazz);
    }

    @Override
    public T findOne(Query query) {
        return mongoTemplate.findOne(query, entityClazz);
    }

    @Override
    public List<T> find(Query query) {
        return mongoTemplate.find(query, entityClazz);
    }

    /**
     * Query-by-example: finds entities matching the populated properties of {@code entity}.
     */
    @Override
    public List<T> find(T entity) {
        Example<T> example = Example.of(entity);
        Query query = Query.query(Criteria.byExample(example));
        return find(query);
    }

    @Override
    public List<T> findAll() {
        return mongoTemplate.findAll(entityClazz);
    }

    @Override
    public Long count(Query query) {
        return mongoTemplate.count(query, entityClazz);
    }

    /**
     * Inserts {@code entity}, first filling in {@code createAt}/{@code updateAt} if they are unset.
     */
    @Override
    public void insert(T entity) {
        setDate(entity);
        mongoTemplate.insert(entity);
    }

    /**
     * Inserts all {@code entities} in one batch, filling in missing timestamps on each.
     * A null or empty list is a no-op.
     */
    @Override
    public void batchInsert(List<T> entities) {
        if (entities == null || entities.isEmpty()) {
            return;
        }
        entities.forEach(this::setDate);
        mongoTemplate.insert(entities, entityClazz);
    }

    /**
     * Sets {@code createAt} and {@code updateAt} when they are null, using one instant so both
     * values are identical for a new entity. Separate {@link Date} instances are used because
     * {@code Date} is mutable and must not be shared between the two properties.
     */
    private void setDate(Entity entity) {
        long now = System.currentTimeMillis();
        if (Objects.isNull(entity.getCreateAt())) {
            entity.setCreateAt(new Date(now));
        }
        if (Objects.isNull(entity.getUpdateAt())) {
            entity.setUpdateAt(new Date(now));
        }
    }

    @Override
    public void updateById(T entity) {
        doUpdate(entity, true);
    }

    @Override
    public void updateNotNullById(T entity) {
        doUpdate(entity, false);
    }

    /**
     * Updates the document whose id equals {@code entity.getId()} from the entity's converted form.
     *
     * @param entity    entity carrying the id and the new field values; its {@code updateAt} is
     *                  set to the current time
     * @param overwrite {@code true} to send the converted document as-is; {@code false} to wrap
     *                  it in {@code $set} so only the fields present in it are written
     */
    private void doUpdate(T entity, boolean overwrite) {
        // Every update refreshes the modification time.
        entity.setUpdateAt(new Date());
        DBObject dbObject = toDbObject(entity, mongoConverter);
        // Never write the creation time on update. Removing it from the converted document itself
        // (rather than passing it to Update.fromDBObject's exclude list, which only inspects
        // top-level keys) also makes the exclusion effective for the $set-wrapped variant,
        // whose only top-level key is "$set".
        dbObject.removeField(Entity.CREATE_KEY);
        Query query = query(Criteria.where(Entity.ID_KEY).is(entity.getId()));
        Update update;
        if (overwrite) {
            // NOTE(review): without $-operators this is a replacement-style update, so MongoDB
            // replaces the stored document wholesale; since createAt is not part of the
            // replacement, the stored createAt is removed rather than preserved — confirm intent.
            update = Update.fromDBObject(dbObject);
        } else {
            // $set only the fields present in the converted document; per the original design,
            // null properties are not written by the converter and so remain untouched.
            update = Update.fromDBObject(new BasicDBObject().append(SET_KEY, dbObject));
        }
        mongoTemplate.updateFirst(query, update, entityClazz);
    }

    /**
     * Applies {@code update} to every document matching {@code query}, also setting
     * {@code updateAt} to the current time unless the caller's update already modifies it.
     */
    @Override
    public void update(Query query, Update update) {
        // The update object is keyed by operator ($set, $inc, ...), not by field name, so the
        // field has to be checked with Update#modifies rather than as a top-level key.
        if (!update.modifies(Entity.UPDATE_KEY)) {
            update.set(Entity.UPDATE_KEY, new Date());
        }
        mongoTemplate.updateMulti(query, update, entityClazz);
    }

    @Override
    public void deleteById(String id) {
        mongoTemplate.remove(query(Criteria.where(Entity.ID_KEY).is(id)), entityClazz);
    }

    /**
     * Deletes all documents whose id is contained in {@code ids}; a null or empty list is a no-op.
     */
    @Override
    public void deleteByIds(List<String> ids) {
        if (ids == null || ids.isEmpty()) {
            return;
        }
        mongoTemplate.remove(query(Criteria.where(Entity.ID_KEY).in(ids)), entityClazz);
    }

    @Override
    public void delete(Query query) {
        mongoTemplate.remove(query, entityClazz);
    }

    /**
     * Returns one page of entities matching {@code criteria}.
     *
     * @param criteria filter; {@code null} matches every document
     * @param pageable page request applied to the query
     * @return the requested page; the data query is skipped when the total count is zero
     */
    @Override
    public Page<T> page(Criteria criteria, Pageable pageable) {
        // Query.query(null) would throw, so treat a missing criteria as "no filter".
        Query query = criteria == null ? new Query() : query(criteria);
        long total = count(query);
        if (total == 0L) {
            return new PageImpl<>(Collections.emptyList(), pageable, total);
        }
        List<T> content = find(query.with(pageable));
        return new PageImpl<>(content, pageable, total);
    }

    protected Query query(Criteria criteria) {
        return Query.query(criteria);
    }

    protected MongoTemplate getMongoTemplate() {
        return mongoTemplate;
    }

    /**
     * Converts an object to a {@link DBObject} using the given writer.
     *
     * @param objectToSave object to convert
     * @param writer       writer performing the conversion
     * @param <S>          type of the object being converted
     * @return a new {@code DBObject} holding the converted fields
     */
    protected <S> DBObject toDbObject(S objectToSave, MongoWriter<S> writer) {
        DBObject dbDoc = new BasicDBObject();
        writer.write(objectToSave, dbDoc);
        return dbDoc;
    }
}

